from io import FileIO
from pathlib import Path

import pytest

from ..code.exploit_mitigations import (
    AnalysisPlugin,
)

# Plugin root directory: two directory levels above this test module.
PLUGIN_DIR = Path(__file__).parent.parent
# Every sample binary referenced below is stored in <plugin root>/test/data.
TEST_DATA_DIR = PLUGIN_DIR / 'test' / 'data'

# Baseline samples: a linked executable, a relocatable object and a shared library.
FILE_PATH_EXE = TEST_DATA_DIR / 'Hallo.out'
FILE_PATH_OBJECT = TEST_DATA_DIR / 'Hallo.o'
FILE_PATH_SHAREDLIB = TEST_DATA_DIR / 'Hallo.so'

# Executable variants, each used to exercise one specific mitigation check.
FILE_PATH_EXE_CANARY = TEST_DATA_DIR / 'Hallo_Canary'
FILE_PATH_EXE_SAFESTACK = TEST_DATA_DIR / 'Hallo_SafeStack'
FILE_PATH_EXE_NO_PIE = TEST_DATA_DIR / 'Hallo_no_pie'
FILE_PATH_EXE_RUNPATH = TEST_DATA_DIR / 'Hallo_runpath'
FILE_PATH_EXE_RPATH = TEST_DATA_DIR / 'Hallo_rpath'
FILE_PATH_EXE_STRIPPED = TEST_DATA_DIR / 'Hallo_stripped'


@pytest.mark.AnalysisPluginTestConfig(plugin_class=AnalysisPlugin)
def test_check_mitigations(analysis_plugin):
    """Analyze the baseline executable and verify every mitigation value and its summary line.

    :param analysis_plugin: fixture providing an ``AnalysisPlugin`` instance
        (configured via the ``AnalysisPluginTestConfig`` marker).
    """
    # Close the handle as soon as analysis is done; a bare ``FileIO(...)``
    # keeps the descriptor open until garbage collection and triggers a
    # ResourceWarning.
    with FileIO(FILE_PATH_EXE) as file_handle:
        result = analysis_plugin.analyze(file_handle, {}, {})
    summary = analysis_plugin.summarize(result)

    assert result.model_dump() == {
        'canary': False,
        'clangcfi': False,
        'nx': True,
        'pie': 'enabled',
        'relro': 'fully enabled',
        'rpath': False,
        'runpath': False,
        'safestack': False,
        'stripped': False,
    }
    # Summary order is not part of the contract, so compare sorted entries.
    assert sorted(summary) == [
        'CANARY disabled',
        'CLANGCFI disabled',
        'NX enabled',
        'PIE enabled',
        'RELRO fully enabled',
        'RPATH disabled',
        'RUNPATH disabled',
        'SAFESTACK disabled',
        'STRIPPED SYMBOLS disabled',
    ]


@pytest.mark.AnalysisPluginTestConfig(plugin_class=AnalysisPlugin)
@pytest.mark.parametrize(
    ('file_path', 'check', 'expected_result'),
    [
        (FILE_PATH_OBJECT, 'pie', 'REL'),
        (FILE_PATH_SHAREDLIB, 'pie', 'DSO'),
        (FILE_PATH_EXE_NO_PIE, 'pie', 'disabled'),
        (FILE_PATH_OBJECT, 'relro', 'disabled'),
        (FILE_PATH_SHAREDLIB, 'relro', 'partially enabled'),
        # TODO: Test clang CFI: enabled
        (FILE_PATH_OBJECT, 'nx', False),
        (FILE_PATH_EXE_CANARY, 'canary', True),
        (FILE_PATH_EXE_SAFESTACK, 'safestack', True),
        (FILE_PATH_EXE_RPATH, 'rpath', True),
        (FILE_PATH_EXE_RUNPATH, 'runpath', True),
        (FILE_PATH_EXE_STRIPPED, 'stripped', True),
    ],
)
def test_checks(analysis_plugin, file_path, check, expected_result):
    """Analyze one sample binary and verify a single mitigation entry of the result.

    :param analysis_plugin: fixture providing an ``AnalysisPlugin`` instance.
    :param file_path: sample binary to analyze.
    :param check: key of the dumped result to inspect (e.g. ``'pie'``, ``'nx'``).
    :param expected_result: value expected under ``check``.
    """
    # Use a context manager so each parametrized case releases its descriptor
    # instead of leaking it until garbage collection.
    with FileIO(file_path) as file_handle:
        result = analysis_plugin.analyze(file_handle, {}, {})
    assert result.model_dump()[check] == expected_result
